/**
 * March 12, 2018
 */
package cn.edu.bjtu.model.core;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import cn.edu.bjtu.api.NetworkModelService;
/**
 * Per-thread holder for a DL4J {@link ComputationGraph}.
 *
 * <p>The serialized model bytes are fetched once, in the constructor, from
 * {@link NetworkModelService#getModelFileStream()}. Each thread that calls {@link #get()}
 * deserializes its own {@code ComputationGraph} from those bytes on first use and then reuses
 * that instance on later calls from the same thread.
 *
 * <p>If deserialization fails with an {@link IOException}, {@link #get()} logs the failure,
 * returns {@code null}, and tries again on that thread's next call.
 *
 * @author Alex
 */
public class ThreadModel extends ThreadLocal<ComputationGraph> {

	private static final Logger logger = LoggerFactory.getLogger(ThreadModel.class);

	/** Serialized model bytes captured at construction; only ever read afterwards. */
	private final byte[] res;

	/**
	 * Captures the serialized model bytes from the given service.
	 *
	 * @param nms service supplying the serialized model; NOTE(review): if it returns
	 *            {@code null}, every later {@link #get()} call will fail with a
	 *            NullPointerException — confirm callers never pass such a service
	 */
	public ThreadModel(NetworkModelService nms) {
		res = nms.getModelFileStream();
	}

	/**
	 * Returns the calling thread's model, deserializing it on first use.
	 *
	 * @return this thread's {@link ComputationGraph}, or {@code null} if loading failed with an
	 *         {@link IOException} (the load is retried on the next call)
	 */
	@Override
	public ComputationGraph get() {
		ComputationGraph cached = super.get();
		if (cached != null) {
			return cached;
		}
		String threadName = Thread.currentThread().getName();
		logger.info(" Thread {} loading model from stream ... ", threadName);
		try (InputStream in = getStream()) {
			ComputationGraph cg = ModelSerializer.restoreComputationGraph(in);
			super.set(cg);
			return cg;
		} catch (IOException e) {
			// The thread-local value is still null here, so the next get() retries the load.
			// e is the trailing argument with no matching placeholder, so SLF4J logs it as the
			// throwable (with stack trace) instead of formatting it into the message text.
			logger.error("Thread {} failed to load model; returning null, will retry when next needed",
					threadName, e);
			return null;
		}
	}

	/**
	 * Opens a fresh stream over the serialized model bytes captured at construction.
	 *
	 * @return a new {@link ByteArrayInputStream} positioned at the start of the model bytes
	 */
	InputStream getStream() {
		return new ByteArrayInputStream(res);
	}
}

